{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch2onnx `SequenceConstruct`\n",
    "\n",
    "参考：[SequenceConstruct](https://onnx.ai/onnx/operators/onnx__SequenceConstruct.html)\n",
    "\n",
    "`SequenceConstruct`: 构建包含 `inputs` 张量的张量序列。`inputs` 中的所有张量必须具有相同的数据类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/frontend\n"
     ]
    }
   ],
   "source": [
    "%cd ../../..\n",
    "import set_env\n",
    "from d2py.utils.file import mkdir\n",
    "root_dir = \".temp\"\n",
    "mkdir(root_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exported graph: graph(%data : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu)):\n",
      "  %/Constant_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name=\"/Constant\"](), scope: PrimModule:: # /tmp/ipykernel_2083243/2703021558.py:6:19\n",
      "  %/Mul_output_0 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], device=cpu) = onnx::Mul[onnx_name=\"/Mul\"](%data, %/Constant_output_0), scope: PrimModule:: # /tmp/ipykernel_2083243/2703021558.py:6:19\n",
      "  %output : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], device=cpu)[] = onnx::SequenceConstruct[onnx_name=\"/SequenceConstruct\"](%data, %/Mul_output_0), scope: PrimModule::\n",
      "  return (%output)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.onnx import OperatorExportTypes, utils\n",
    "class PrimModule(torch.jit.ScriptModule):\n",
    "    @torch.jit.script_method\n",
    "    def forward(self, x):\n",
    "        return [x, x*2]\n",
    "\n",
    "model = PrimModule()\n",
    "model.eval()\n",
    "\n",
    "shape = 1, 3, 8, 8\n",
    "x = torch.rand(*shape, dtype=torch.float32, requires_grad=False)\n",
    "model = PrimModule()\n",
    "# 导出模型\n",
    "output_name = \"SequenceConstruct\"\n",
    "utils.export(\n",
    "    model,               # torch 模型\n",
    "    x,                         # 模型输入或者对于多个输入，使用元组\n",
    "    f\"{root_dir}/{output_name}.onnx\",               # 模型保存的位置（可以是文件或类似文件的对象）\n",
    "    export_params=True,        # 将训练后的参数权重存储在模型文件内\n",
    "    opset_version=17,          # 导出模型的 ONNX 版本\n",
    "    do_constant_folding=True,  # 是否执行常量折叠以进行优化\n",
    "    input_names = ['data'],    # 模型的输入名称\n",
    "    output_names = ['output'], # 模型的输出名称\n",
    "    keep_initializers_as_inputs=True,\n",
    "    # export_modules_as_functions=True,\n",
    "    verbose=True,\n",
    "    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,\n",
    "    # dynamic_axes={'data' : {0 : 'batch_size'},    # 可变长度的轴\n",
    "    #               'output' : {0 : 'batch_size'}}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![模型结构](images/SequenceConstruct.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>data: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>Mul<span style=\"color: #AA22FF; font-weight: bold\">.</span>data:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> (Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32], Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32]) {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> multiply(<span style=\"color: #AA22FF; font-weight: bold\">%</span>data, <span style=\"color: #008000\">2</span>f <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>float32 span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>Constant:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>Mul:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>;\n",
       "  (<span style=\"color: #AA22FF; font-weight: bold\">%</span>data, <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>(Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32], Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32]) span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>SequenceConstruct:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import onnx\n",
    "import tvm\n",
    "from tvm import relay\n",
    "onnx_model = onnx.load(f\"{root_dir}/{output_name}.onnx\")\n",
    "mod, params = relay.frontend.from_onnx(onnx_model, {\"data\": shape}, freeze_params=True)\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    mod = relay.quantize.prerequisite_optimize(mod, params)\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.1.undefined"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
